{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Notebook written by [Zhedong Zheng](https://github.com/zhedongzheng)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import tensorflow as tf\n",
    "import numpy as np\n",
    "import os\n",
    "from tqdm import tqdm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "BATCH_SIZE = 3\n",
    "VOCAB_SIZE = 20000\n",
    "TF_RECORD_PATH = './imdb_train_var_len.tfrecord'\n",
    "\n",
    "(X_train, y_train), (_, _) = tf.keras.datasets.imdb.load_data(num_words=VOCAB_SIZE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████████████████████| 25000/25000 [00:37<00:00, 670.52it/s]\n"
     ]
    }
   ],
   "source": [
    "if not os.path.isfile(TF_RECORD_PATH):\n",
    "    writer = tf.python_io.TFRecordWriter(TF_RECORD_PATH)\n",
    "    for sent, label in tqdm(zip(X_train, y_train), total=len(X_train), ncols=70):\n",
    "        example = tf.train.Example(\n",
    "            features = tf.train.Features(\n",
    "                 feature = {\n",
    "                   'sent': tf.train.Feature(\n",
    "                       int64_list=tf.train.Int64List(value=sent)),\n",
    "                   'label': tf.train.Feature(\n",
    "                       int64_list=tf.train.Int64List(value=[label])),\n",
    "                   }))\n",
    "        serialized = example.SerializeToString()\n",
    "        writer.write(serialized)\n",
    "    writer.close()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "(?, ?) (?,)\n"
     ]
    }
   ],
   "source": [
    "def _parse_fn(example_proto):\n",
    "    parsed_feats = tf.parse_single_example(\n",
    "        example_proto,\n",
    "        features={\n",
    "            'sent': tf.FixedLenSequenceFeature([], tf.int64, allow_missing=True),\n",
    "            'label': tf.FixedLenFeature([], tf.int64)\n",
    "        })\n",
    "    return parsed_feats\n",
    "\n",
    "dataset = tf.data.TFRecordDataset([TF_RECORD_PATH])\n",
    "dataset = dataset.map(_parse_fn)\n",
    "dataset = dataset.padded_batch(BATCH_SIZE, {'sent': [None], 'label': []})\n",
    "iterator = dataset.make_one_shot_iterator()\n",
    "batch = iterator.get_next()\n",
    "X_batch, y_batch = batch['sent'], batch['label']\n",
    "print(X_batch.get_shape(), y_batch.get_shape())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "[218, 189, 141] -> (3, 218)\n",
      "\n",
      "Original:\n",
      "[   1   14   47    8   30   31    7    4  249  108    7    4 5974   54\n",
      "   61  369   13   71  149   14   22  112    4 2401  311   12   16 3711\n",
      "   33   75   43 1829  296    4   86  320   35  534   19  263 4821 1301\n",
      "    4 1873   33   89   78   12   66   16    4  360    7    4   58  316\n",
      "  334   11    4 1716   43  645  662    8  257   85 1200   42 1228 2578\n",
      "   83   68 3912   15   36  165 1539  278   36   69    2  780    8  106\n",
      "   14 6905 1338   18    6   22   12  215   28  610   40    6   87  326\n",
      "   23 2300   21   23   22   12  272   40   57   31   11    4   22   47\n",
      "    6 2307   51    9  170   23  595  116  595 1352   13  191   79  638\n",
      "   89    2   14    9    8  106  607  624   35  534    6  227    7  129\n",
      "  113]\n",
      "\n",
      "Padded:\n",
      "[   1   14   47    8   30   31    7    4  249  108    7    4 5974   54\n",
      "   61  369   13   71  149   14   22  112    4 2401  311   12   16 3711\n",
      "   33   75   43 1829  296    4   86  320   35  534   19  263 4821 1301\n",
      "    4 1873   33   89   78   12   66   16    4  360    7    4   58  316\n",
      "  334   11    4 1716   43  645  662    8  257   85 1200   42 1228 2578\n",
      "   83   68 3912   15   36  165 1539  278   36   69    2  780    8  106\n",
      "   14 6905 1338   18    6   22   12  215   28  610   40    6   87  326\n",
      "   23 2300   21   23   22   12  272   40   57   31   11    4   22   47\n",
      "    6 2307   51    9  170   23  595  116  595 1352   13  191   79  638\n",
      "   89    2   14    9    8  106  607  624   35  534    6  227    7  129\n",
      "  113    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0    0    0    0    0    0    0\n",
      "    0    0    0    0    0    0    0    0]\n"
     ]
    }
   ],
   "source": [
    "sess = tf.Session()\n",
    "x, y = sess.run([X_batch, y_batch])\n",
    "print([len(x) for x in X_train[:BATCH_SIZE]], '->', x.shape)\n",
    "print()\n",
    "print(\"Original:\")\n",
    "print(np.array(X_train[BATCH_SIZE-1]))\n",
    "print()\n",
    "print(\"Padded:\")\n",
    "print(x[-1])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
